import socket
import threading
import time

from .config import FROM_UDP_PORT, FROM_TCP_PORT
from protocol_utils import *
from protocol_utils.config import *


class Connection(object):
    """UDP endpoint that sends PING messages and tallies the PONG replies.

    Outgoing messages are packed and signed with ``pack_v4`` using
    ``private_key``; incoming datagrams are decoded with ``unpack_v4`` by
    :meth:`serve_forever`, which :meth:`open` runs on a background thread.
    Each PING is remembered in ``sent`` under its token (the first
    ``MAC_SIZE`` bytes of the packed message) so the matching PONG can be
    paired with it.
    """

    def __init__(self, quiet=False, show=False, from_addr=("127.0.0.1", FROM_UDP_PORT, FROM_TCP_PORT),
                 **conf):
        """Create a connection; no socket is opened until :meth:`open`.

        :param quiet: suppress all console output.
        :param show: print the public address and the private key
            (ignored when ``quiet`` is set).
        :param from_addr: ``(ip, udp_port, tcp_port)`` advertised as the
            sender address in PINGs; the UDP port is also the local bind port.
        :param conf: optional ``expiration`` (seconds; used both for message
            expiry and as the socket receive timeout) and ``private_key``
            used to sign outgoing messages.
        """
        self.from_addr: Tuple[str, int, int] = from_addr
        self.expiration: int = conf.get('expiration', EXPIRATION)
        self.private_key: datatypes.PrivateKey = conf.get('private_key', None)
        self.socket: "Optional[socket.socket]" = None
        self.serve_thread = None
        self.quiet = quiet

        # token -> (ping sequence number, datetime the ping was recorded)
        # for every PING still awaiting its PONG.
        self.sent = dict()
        self.ping_count = 0
        self.pong_count = 0

        if show and not quiet:
            print("public")
            print(self.private_key.public_key.to_address())
            print("private")
            print(self.private_key)

    def open(self):
        """Bind the UDP socket and start the background receive thread.

        The socket listens on all interfaces at ``from_addr[1]`` with a
        receive timeout of ``expiration`` seconds. If configuring or binding
        fails, the new socket is closed before the error propagates.
        """
        sock = socket.socket(socket.AF_INET,
                             socket.SOCK_DGRAM)
        try:
            sock.settimeout(self.expiration)
            sock.bind(('', self.from_addr[1]))
        except Exception:
            # Don't leak the descriptor (e.g. when the port is already in use).
            sock.close()
            raise
        self.socket = sock
        self.serve_thread = threading.Thread(target=self.serve_forever)
        self.serve_thread.start()

    def close(self):
        """Shut down and close the socket; a no-op if it is not open.

        ``self.socket`` is cleared first so :meth:`serve_forever` sees the
        connection as closed and leaves its loop.
        """
        if self.socket:
            sock = self.socket
            self.socket = None
            try:
                sock.shutdown(socket.SHUT_RDWR)
            except socket.error:
                # Unconnected UDP sockets commonly report ENOTCONN here.
                pass
            finally:
                sock.close()

    def send(self, remote: Tuple[str, int, int], cmd_id: int, payload: Sequence[Any]):
        """Pack ``payload`` as command ``cmd_id`` and send it to ``remote``.

        Sending is best-effort: transmission errors are ignored.

        :param remote: ``(ip, udp_port, tcp_port)``; only ip and UDP port are used.
        :returns: the packed message bytes.
        """
        message = pack_v4(cmd_id, payload, self.private_key)
        self._transmit(remote, message)
        return message

    def _transmit(self, remote, message):
        """Best-effort ``sendto`` of an already-packed message to ``remote``."""
        sock = self.socket
        if sock is None:
            # Not open (or already closed): nothing to send on.
            return
        try:
            sock.sendto(message, (remote[0], remote[1]))
        except Exception:
            # Deliberately best-effort: unreachable or invalid peers must not
            # abort the caller.
            pass

    def _get_msg_expiration(self):
        """Return ``now + expiration`` (whole seconds) as an RLP big-endian int."""
        return rlp.sedes.big_endian_int.serialize(int(time.time() + self.expiration))

    def serve_forever(self) -> None:
        """Receive and dispatch datagrams until the connection stops.

        The loop ends when ``self.socket`` has been cleared by :meth:`close`.
        If no datagram arrives within the socket timeout, the connection is
        closed and ``SystemExit(0)`` is raised. On the thread started by
        :meth:`open`, that silently ends only this thread. Errors while
        receiving or handling a single datagram are ignored, so one bad packet
        cannot stop the server.
        """
        while True:
            # Read the attribute once so close() on another thread cannot turn
            # it into None between the check and the recvfrom call.
            sock = self.socket
            if not sock:
                break
            try:
                datagram, (ip_address, port) = sock.recvfrom(DISCOVERY_DATAGRAM_BUFFER_SIZE)
                # Handled synchronously, as the original Thread(...).run() did;
                # this keeps access to self.sent and the counters single-threaded.
                self.consume_datagram((ip_address, port, port), datagram)
            except socket.timeout:
                # socket.timeout (an alias of TimeoutError since 3.10) is caught
                # explicitly so the timeout also stops the loop on older Pythons.
                self.close()
                raise SystemExit(0)
            except Exception:
                pass

    def consume_datagram(self, address: Tuple[str, int, int], datagram):
        """Handle one received datagram from ``address``."""
        self.handle_msg(address, datagram)

    def handle_msg(self, _, message: bytes) -> None:
        """Decode ``message`` and dispatch it by command id.

        PING, FIND_NODE and NEIGHBOURS messages are ignored; PONGs go to
        :meth:`recv_pong`. The sender address (first argument) is unused.

        :raises ValueError: for an unrecognised command id.
        """
        try:
            remote_pubkey, cmd_id, payload, message_hash = unpack_v4(message)
        except SyntaxError:
            # NOTE(review): only SyntaxError is treated as "malformed packet";
            # confirm which exceptions unpack_v4 actually raises.
            return
        if cmd_id in (CMD_PING, CMD_FIND_NODE, CMD_NEIGHBOURS):
            return
        elif cmd_id == CMD_PONG:
            return self.recv_pong(payload)
        else:
            raise ValueError(f"Unknown command id: {cmd_id}")

    def ping(self, ip: str, udp: int, tcp: int):
        """Send a PING to ``(ip, udp, tcp)`` and return its token.

        Every 10th ping, when fewer than 25% of pings have been answered, a
        progress line is printed (unless ``quiet``). The token is recorded in
        ``sent`` *before* the datagram goes out, so a fast PONG handled on the
        serve thread always finds it.
        """
        if not self.quiet:
            # ping_count == 0 never prints: pong_count < 0 is always False,
            # which also avoids dividing by zero below.
            if self.pong_count < 0.25 * self.ping_count and self.ping_count % 10 == 0:
                print(
                    f"{self.ping_count} pings sent so far, {self.pong_count} pongs received ({round(((self.ping_count - self.pong_count) / self.ping_count) * 100, 2)}% unanswered)")

        remote = (ip, udp, tcp)
        version = rlp.sedes.big_endian_int.serialize(PROTOCOL_VERSION)
        expiration = self._get_msg_expiration()
        this_address = [ipaddress.ip_address(self.from_addr[0]).packed, enc_port(self.from_addr[1]),
                        enc_port(self.from_addr[2])]
        node_address = [ipaddress.ip_address(remote[0]).packed, enc_port(remote[1]), enc_port(remote[2])]
        payload = (version, this_address, node_address, expiration)

        message = pack_v4(CMD_PING, payload, self.private_key)
        token = Hash32(message[:MAC_SIZE])

        # Register before transmitting to avoid racing the receive thread.
        self.sent[token] = (self.ping_count, datetime.now())
        self.ping_count += 1
        self._transmit(remote, message)
        return token

    def recv_pong(self, payload: Sequence[Any]) -> None:
        """Match a PONG payload to its outstanding PING and count it.

        Payloads with fewer than 3 elements are ignored. Pongs whose token is
        not in ``sent`` (unsolicited, duplicate or stale) are ignored and not
        counted.
        """
        # The pong payload should have at least 3 elements: to, token, expiration
        if len(payload) < 3:
            return
        _, token, _expiration = payload[:3]

        entry = self.sent.pop(token, None)
        if entry is None:
            return
        ping_number, ping_time = entry
        self.pong_count += 1

        if not self.quiet:
            delta = datetime.now() - ping_time
            # total_seconds() avoids the unpadded "seconds.microseconds"
            # rendering (5 ms used to print as 0.5000s).
            print(
                f"got a pong after {delta.total_seconds():.6f}s corresponding to ping {ping_number}")